package cloud.seri.gateway.security.oauth2

import cloud.seri.gateway.config.oauth2.OAuth2Properties
import org.apache.http.conn.util.InetAddressUtils.isIPv4Address
import org.apache.http.conn.util.InetAddressUtils.isIPv6Address
import org.apache.http.conn.util.PublicSuffixMatcherLoader
import org.slf4j.LoggerFactory
import org.springframework.boot.json.JsonParserFactory
import org.springframework.security.jwt.JwtHelper
import org.springframework.security.oauth2.common.OAuth2AccessToken
import org.springframework.security.oauth2.common.OAuth2RefreshToken
import org.springframework.security.oauth2.common.exceptions.InvalidTokenException
import org.springframework.security.oauth2.provider.token.AccessTokenConverter
import org.springframework.util.StringUtils
import java.util.Locale
import javax.servlet.http.Cookie
import javax.servlet.http.HttpServletRequest
import javax.servlet.http.HttpServletResponse

/**
 * Helps with OAuth2 cookie handling: creates, looks up, strips and clears the access token
 * and refresh token cookies.
 *
 * @property oAuth2Properties configuration supplying the optional explicit cookie domain and the
 * session timeout for non-"remember me" sessions.
 */
class OAuth2CookieHelper(private val oAuth2Properties: OAuth2Properties) {

    /**
     * Public suffix matcher (to strip private subdomains off cookie scope).
     * Uses the loader's default list; alternatively, an up-to-date list could be obtained
     * by passing a URL to the loader.
     */
    private val suffixMatcher = PublicSuffixMatcherLoader.getDefault()

    private val log = LoggerFactory.getLogger(OAuth2CookieHelper::class.java)

    /**
     * Used to parse JWT claims.
     */
    private val jsonParser = JsonParserFactory.getJsonParser()

    /**
     * Create cookies using the provided values.
     *
     * @param request the request we are handling.
     * @param accessToken the access token and enclosed refresh token for our cookies.
     * @param rememberMe whether the user had originally checked "remember me".
     * @param result will get the resulting cookies set.
     */
    fun createCookies(
        request: HttpServletRequest,
        accessToken: OAuth2AccessToken,
        rememberMe: Boolean,
        result: OAuth2Cookies
    ) {
        val domain = getCookieDomain(request)
        log.debug("creating cookies for domain {}", domain)

        val accessTokenCookie = Cookie(ACCESS_TOKEN_COOKIE, accessToken.value)
        setCookieProperties(accessTokenCookie, request.isSecure, domain)
        log.debug("created access token cookie '{}'", accessTokenCookie.name)

        val refreshTokenCookie = createRefreshTokenCookie(accessToken.refreshToken, rememberMe)
        setCookieProperties(refreshTokenCookie, request.isSecure, domain)
        log.debug(
            "created refresh token cookie '{}', age: {}",
            refreshTokenCookie.name,
            refreshTokenCookie.maxAge
        )

        result.setCookies(accessTokenCookie, refreshTokenCookie)
    }

    /**
     * Create a cookie out of the given refresh token.
     *
     * The cookie value is the refresh token value as-is. Whether "remember me" was checked is
     * encoded in the cookie name: [REFRESH_TOKEN_COOKIE] for remember me, [SESSION_TOKEN_COOKIE]
     * otherwise (see [isRememberMe]).
     *
     * Without remember me the cookie is a session cookie (max-age `-1`). With remember me the
     * max-age is derived from the token's `exp` claim, shortened by
     * [REFRESH_TOKEN_EXPIRATION_WINDOW_SECS]; if the token carries no `exp` claim the cookie
     * stays a session cookie.
     *
     * @param refreshToken the refresh token to store.
     * @param rememberMe whether the user had checked "remember me".
     * @return the new cookie; domain/path/flags are not set here.
     * @throws InvalidTokenException if the `exp` claim is present but not numeric.
     */
    private fun createRefreshTokenCookie(refreshToken: OAuth2RefreshToken, rememberMe: Boolean): Cookie {
        val value = refreshToken.value
        val name = if (rememberMe) REFRESH_TOKEN_COOKIE else SESSION_TOKEN_COOKIE
        // a negative max-age makes it a non-persistent (browser session) cookie
        var cookieMaxAge = -1
        if (rememberMe) {
            // get expiration in seconds from the token's "exp" claim; accept any numeric type
            // because the JSON parser in use decides whether it yields Integer, Long or Double
            val exp = getClaim(value, AccessTokenConverter.EXP, Number::class.java)
            if (exp != null) {
                val remainingSecs = exp.toLong() - currentEpochSeconds()
                log.debug("refresh token valid for another {} secs", remainingSecs)
                // let cookie expire a bit earlier than the token to avoid race conditions;
                // clamp at 0 so an (almost) expired token yields an immediately expiring cookie
                // rather than a negative max-age, which would mean "session cookie"
                cookieMaxAge = (remainingSecs - REFRESH_TOKEN_EXPIRATION_WINDOW_SECS)
                    .coerceIn(0L, Int.MAX_VALUE.toLong())
                    .toInt()
            }
        }
        val refreshTokenCookie = Cookie(name, value)
        refreshTokenCookie.maxAge = cookieMaxAge
        return refreshTokenCookie
    }

    /**
     * Checks if the refresh token session has expired.
     * Only makes sense for non-persistent cookies, i.e. when remember me was not checked.
     * The motivation for this is that we want to throw out a user after a while if he's inactive.
     * We cannot do this via refresh token validity because that one is also used for remember me.
     *
     * The session never expires if the cookie is a "remember me" cookie, if no non-negative
     * session timeout is configured, or if the token has no `iat` claim.
     *
     * @param refreshCookie the refresh token cookie to check.
     * @return true, if the session is expired.
     * @throws InvalidTokenException if the `iat` claim is present but not numeric.
     */
    fun isSessionExpired(refreshCookie: Cookie): Boolean {
        if (isRememberMe(refreshCookie)) { // no session expiration for "remember me"
            return false
        }
        // read non-remember-me session length in secs
        val validity = oAuth2Properties.webClientConfiguration.sessionTimeoutInSeconds
        if (validity == null || validity < 0) { // no session expiration configured
            return false
        }
        // token creation timestamp in secs is missing, session does not expire
        val iat = getClaim(getRefreshTokenValue(refreshCookie), "iat", Number::class.java)
            ?: return false
        val sessionDuration = currentEpochSeconds() - iat.toLong()
        log.debug("session duration {} secs, will timeout at {}", sessionDuration, validity)
        return sessionDuration > validity // session has expired
    }

    /**
     * Retrieve the given claim from the given token.
     *
     * @param refreshToken the JWT token to examine.
     * @param claimName name of the claim to get.
     * @param clazz the [Class] we expect to find there.
     * @return the desired claim, or `null` if the token does not contain it.
     * @throws InvalidTokenException if the claim is present but not an instance of [clazz].
     */
    private fun <T> getClaim(refreshToken: String, claimName: String, clazz: Class<T>): T? {
        val claims = JwtHelper.decode(refreshToken).claims
        val claimValue = jsonParser.parseMap(claims)[claimName] ?: return null
        if (!clazz.isInstance(claimValue)) {
            throw InvalidTokenException("claim is not of expected type: $claimName")
        }
        // checked cast; safe because of the isInstance test above
        return clazz.cast(claimValue)
    }

    /**
     * Set cookie properties of access and refresh tokens.
     *
     * @param cookie the cookie to modify.
     * @param isSecure whether it is coming from a secure request.
     * @param domain the domain for which the cookie is valid. If `null`, then will fall back to default.
     */
    private fun setCookieProperties(cookie: Cookie, isSecure: Boolean, domain: String?) {
        cookie.isHttpOnly = true
        cookie.path = "/"
        cookie.secure = isSecure // if the request comes per HTTPS set the secure option on the cookie
        if (domain != null) {
            cookie.domain = domain
        }
    }

    /**
     * Logs the user out by clearing all cookies.
     *
     * @param httpServletRequest the request containing the Cookies.
     * @param httpServletResponse the response used to clear them.
     */
    fun clearCookies(httpServletRequest: HttpServletRequest, httpServletResponse: HttpServletResponse) {
        val domain = getCookieDomain(httpServletRequest)
        for (cookieName in COOKIE_NAMES) {
            clearCookie(httpServletRequest, httpServletResponse, domain, cookieName)
        }
    }

    /**
     * Adds an empty cookie with max-age 0 under the given name to the response, using the same
     * path/domain/flags as the cookies we create, so that the client discards it.
     *
     * @param httpServletRequest the current request (used for the secure flag).
     * @param httpServletResponse the response the clearing cookie is added to.
     * @param domain the cookie domain, or `null` for the default.
     * @param cookieName name of the cookie to clear.
     */
    private fun clearCookie(
        httpServletRequest: HttpServletRequest,
        httpServletResponse: HttpServletResponse,
        domain: String?,
        cookieName: String
    ) {
        val cookie = Cookie(cookieName, "")
        setCookieProperties(cookie, httpServletRequest.isSecure, domain)
        cookie.maxAge = 0
        httpServletResponse.addCookie(cookie)
        log.debug("clearing cookie {}", cookie.name)
    }

    /**
     * Returns the top level domain of the server from the request. This is used to limit the Cookie
     * to the top domain instead of the full domain name.
     *
     * A lot of times, individual gateways of the same domain get their own subdomain but authentication
     * shall work across all subdomains of the top level domain.
     *
     * For example, when sending a request to `app1.domain.com`,
     * this returns `.domain.com`.
     *
     * An explicitly configured cookie domain always takes precedence.
     *
     * @param request the HTTP request we received from the client.
     * @return the top level domain to set the cookies for.
     * Returns `null` if the domain is not under a public suffix (.com, .co.uk), e.g. for localhost,
     * if the server name is an IP address, or if the server name already equals its domain root.
     */
    private fun getCookieDomain(request: HttpServletRequest): String? {
        oAuth2Properties.webClientConfiguration.cookieDomain?.let { return it }

        // if not explicitly defined, use top-level domain;
        // lower-case locale-independently and strip off leading www.
        val host = request.serverName.toLowerCase(Locale.ROOT).removePrefix("www.")
        if (isIPv4Address(host) || isIPv6Address(host)) {
            // IP addresses have no top-level domain, stick with default domain
            return null
        }
        // strip off private subdomains, leaving public TLD only
        val root = suffixMatcher.getDomainRoot(host)
        // preserve leading dot; no (distinct) top-level domain means default domain
        return if (root != null && root != host) ".$root" else null
    }

    /**
     * Strip our token cookies from the array.
     *
     * @param cookies the cookies we receive as input.
     * @return the new cookie array without our tokens; the input array itself if nothing was removed.
     */
    internal fun stripCookies(cookies: Array<Cookie>): Array<Cookie> {
        val cc = CookieCollection(*cookies)
        return if (cc.removeAll(COOKIE_NAMES)) {
            cc.toArray()
        } else cookies
    }

    companion object {
        /**
         * Name of the access token cookie.
         */
        const val ACCESS_TOKEN_COOKIE = OAuth2AccessToken.ACCESS_TOKEN
        /**
         * Name of the refresh token cookie in case of remember me.
         */
        const val REFRESH_TOKEN_COOKIE = OAuth2AccessToken.REFRESH_TOKEN
        /**
         * Name of the session-only refresh token in case the user did not check remember me.
         */
        const val SESSION_TOKEN_COOKIE = "session_token"
        /**
         * The names of the Cookies we set.
         */
        private val COOKIE_NAMES = listOf(ACCESS_TOKEN_COOKIE, REFRESH_TOKEN_COOKIE, SESSION_TOKEN_COOKIE)
        /**
         * Number of seconds to expire refresh token cookies before the enclosed token expires.
         * This makes sure we don't run into race conditions where the cookie is still there but
         * expires while we process it.
         */
        private const val REFRESH_TOKEN_EXPIRATION_WINDOW_SECS = 3L

        /**
         * Current time in whole seconds since the epoch, kept as [Long] to avoid `Int` overflow.
         */
        private fun currentEpochSeconds(): Long = System.currentTimeMillis() / 1000L

        /**
         * Get the access token cookie from the request.
         *
         * @param request the request containing the cookie.
         * @return the access token [Cookie]; or `null`, if not found or without text.
         */
        fun getAccessTokenCookie(request: HttpServletRequest): Cookie? {
            return getCookie(request, ACCESS_TOKEN_COOKIE)
        }

        /**
         * Get the refresh token cookie from the request, preferring the persistent
         * ("remember me") cookie over the session-only one.
         *
         * @param request the request containing the cookie.
         * @return the refresh token [Cookie]; or `null`, if neither cookie is present.
         */
        fun getRefreshTokenCookie(request: HttpServletRequest): Cookie? {
            return getCookie(request, REFRESH_TOKEN_COOKIE)
                ?: getCookie(request, SESSION_TOKEN_COOKIE)
        }

        /**
         * Get a cookie by name from the given servlet request.
         * Cookies whose value has no text are skipped.
         *
         * @param request the request containing the cookie.
         * @param cookieName the case-sensitive name of the cookie to get.
         * @return the resulting [Cookie]; or `null`, if not found.
         */
        private fun getCookie(request: HttpServletRequest, cookieName: String): Cookie? {
            return request.cookies?.firstOrNull { cookie ->
                cookie.name == cookieName && StringUtils.hasText(cookie.value)
            }
        }

        /**
         * Returns true if the refresh token cookie was set with remember me checked.
         * We can recognize this by the name of the cookie.
         *
         * @param refreshTokenCookie the cookie holding the refresh token.
         * @return true, if it was set persistently (i.e. for "remember me").
         */
        fun isRememberMe(refreshTokenCookie: Cookie): Boolean {
            return refreshTokenCookie.name == REFRESH_TOKEN_COOKIE
        }

        /**
         * Extracts the refresh token from the refresh token cookie.
         * If the cookie value contains a pipe '|' after its first character, everything up to
         * and including the first pipe is treated as additional information and dropped;
         * otherwise the value is returned unchanged.
         *
         * @param refreshCookie the cookie we store the value in.
         * @return the refresh JWT from the cookie.
         */
        fun getRefreshTokenValue(refreshCookie: Cookie): String {
            val value = refreshCookie.value
            val i = value.indexOf('|')
            return if (i > 0) {
                value.substring(i + 1)
            } else value
        }
    }
}
